Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deterministic error propagation for distributed (training) tasks #5598

Merged
merged 10 commits into from
Nov 18, 2024

Conversation

fg91
Copy link
Member

@fg91 fg91 commented Jul 27, 2024

Why are the changes needed?

See introduction of RFC document.

@fg91 fg91 added the rfc A label for RFC issues label Jul 27, 2024
@fg91 fg91 self-assigned this Jul 27, 2024
Signed-off-by: Fabio M. Graetz, Ph.D. <[email protected]>
Copy link

codecov bot commented Jul 27, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 36.21%. Comparing base (5b0d787) to head (6b55c00).
Report is 213 commits behind head on master.

Additional details and impacted files
@@             Coverage Diff             @@
##           master    #5598       +/-   ##
===========================================
- Coverage   60.97%   36.21%   -24.77%     
===========================================
  Files         794     1303      +509     
  Lines       51488   109568    +58080     
===========================================
+ Hits        31397    39683     +8286     
- Misses      17199    65765    +48566     
- Partials     2892     4120     +1228     
Flag Coverage Δ
unittests-datacatalog 51.37% <ø> (-17.95%) ⬇️
unittests-flyteadmin 55.63% <ø> (-3.06%) ⬇️
unittests-flytecopilot 12.17% <ø> (-5.62%) ⬇️
unittests-flytectl 62.21% <ø> (-5.82%) ⬇️
unittests-flyteidl 7.12% <ø> (-71.93%) ⬇️
unittests-flyteplugins 53.35% <ø> (-8.50%) ⬇️
unittests-flytepropeller 41.76% <ø> (-15.54%) ⬇️
unittests-flytestdlib 55.35% <ø> (-10.25%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.


🚨 Try these New Features:

* As a Flyte user trying to understand why a distributed training task failed, I currently cannot rely on the error reported in the Flyte Console (UI) being the root cause error.
* Instead, I have to search the logs of each worker pod. For distributed training jobs with dozens or even hundreds of worker pods, this can be tedious.
* (Current remedies include combining all worker pods in stackdriver logs using a wildcard in the pod name and then filtering by severity.)
* As a Flyte user marking specific errors that can occur in distributed training jobs as retriable (using a `FlyteRecoverableException`), I want Flyte to deterministically determine the root cause error so that the retry behaviour does not suffer from a race condition.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Love this one. It is subtle, but in a typical PyTorch job, the first error is an instance of ChildError, and Flyte plugin for PyTorch handles that properly by checking if the root cause is a recoverable exception. The latter errors are often different, such as Rendezvous errors, and Flyte treats them as non-recoverable. With the current behavior of latter pod errors taking over, we end up treating these as non-recoverable.

@fg91 fg91 requested a review from bgedik July 31, 2024 19:54
Comment on lines 63 to 65
* We could add a `MultipleErrorFiles` property to `PluginProperties` (see https://github.com/flyteorg/flyte/blob/4514860cf56ba62717f6c207f269410a8c1a5461/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go#L34). The PyTorch plugin, for instance, would then pass `true` for `MultipleErrorFiles` [here](https://github.com/flyteorg/flyte/blob/4514860cf56ba62717f6c207f269410a8c1a5461/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go#L31).

Currently, [here](https://github.com/flyteorg/flyte/blob/4514860cf56ba62717f6c207f269410a8c1a5461/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go#L290) in the plugin manager, where we call `NewRemoteFileOutputReader`, we do have access to `e.plugin`, and thus to `PluginProperties` and could make use of that information to instantiate another output reader.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 to this.

If we feel that MultiErrorFileRemoteFileOutputReader ends up being used more widely in plugins we can "bump" it up in the configuration of plugins.

To transport this information from `pyflyte-execute` to flytepropeller, we propose to add an additional field `pod_name` (or `container_name`) to `message ContainerError`.

Open question:
* Where does flytepropeller add this info for it to be displayed as part of the error in the UI?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pvditt , let's set some time to clarify how this is done currently.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@eapolinario @fg91 apologies for just now seeing this.

DemystifyFailure and DemystifyPending get the messages from the pods.

@wild-endeavor
Copy link
Contributor

Flytekit related

  • each pod uploads a unique error file name.
    • this is good yes, but i'm not sure it's necessary to pin it to the pod name. when it makes sense, flyte will open up its abstraction layer and expose the underlying k8s construct (like pod template), but i think in this case maybe it doesn't make sense to. the file name that propeller passes to the plugin can still be the pod name, esp if that's helpful, but maybe we can use a different env var?
  • the FLYTE_INTERNAL_ERROR_PROPAGATION - other than signaling to flytekit to use a special error file name, will this do anything else?
  • I think adding timestamps to errors is fine, and if that's the only signal a plugin has to go off on, then okay, but in general i don't like relying on timestamps. wall clock time can be off.
  • Should we put the new files in an errors/ folder as a sibling to the error.pb today?

So broadly speaking there's a map/reduce problem with error docs. With flytekit changes, we can take care of the map part pretty easily. On the reduce side...

  • Are there cases where there can be one special Python worker that can iterate through all the child errors and merge them into one error document? (I'm assuming the answer is even if there are, this is rare and more likely it'll need to be handled by the backend plugin code)
  • On the back-end side, there's two potential kinds of plugins - K8s plugins and non-K8s plugins.
    • For non-K8s plugins, we'd probably just ask plugin writers to do this as a special case of Handle() right? Scan output path for multiple error pb files, and somehow collate or merge them into one.
    • For K8s plugins, we're saying that it feels too weird to do it as part of GetTaskPhase() right? I kinda agree.

@fg91 fg91 marked this pull request as draft August 29, 2024 16:14
@fg91
Copy link
Member Author

fg91 commented Aug 29, 2024

FYI I made the PR a draft again and moved the RFC back to "new" in the Kanban. I don't have a full picture yet of how we communicate from propeller to admin/the UI in which of the workers the error occurred. Once I understood the details, I'll update the RFC and mark it as ready for review again.

@fg91
Copy link
Member Author

fg91 commented Aug 29, 2024

Thank you for your review @wild-endeavor ! 🙇

  • each pod uploads a unique error file name.

    • this is good yes, but i'm not sure it's necessary to pin it to the pod name. when it makes sense, flyte will open up its abstraction layer and expose the underlying k8s construct (like pod template), but i think in this case maybe it doesn't make sense to. the file name that propeller passes to the plugin can still be the pod name, esp if that's helpful, but maybe we can use a different env var?

Ultimately we want to show in the UI which worker pod failed so that one immediately knows which pod's logs to check. I was first thinking that propeller takes this information from the error file name. But now, thinking about it, maybe it's better if we properly persist the pod name in message ContainerError? The file name could then just contain a random UUID to avoid collisions. That leaves the question how the pod entrypoint knows the pod name. I see two options: 1) either from the HOSTNAME env var or 2) from a potential FLYTE_INTERNAL_POD_NAME env var that would be set via the k8s downward api. Option 1 would be simpler but I'm not 100% sure whether there are edge cases where the HOSTNAME is not equal to the pod name.

  • the FLYTE_INTERNAL_ERROR_PROPAGATION - other than signaling to flytekit to use a special error file name, will this do anything else?

In the RFC we discuss that in the future there might be potential other strategies to determine the root cause error other than earliest timestamp. I could imagine that the pod entrypoint might have to do different things based on this strategy. For now, FLYTE_INTERNAL_ERROR_PROPAGATION=earliest would tell the pod entrypoint that is has to include the timestamp in ContainerError.
In theory we could not distinguish between strategies (with a single case for now) but just set FLYTE_INTERNAL_MULTI_ERROR.

  • I think adding timestamps to errors is fine, and if that's the only signal a plugin has to go off on, then okay, but in general i don't like relying on timestamps. wall clock time can be off.

I agree with the criticism that time is not a reliable indicator but it is the strategy that torch distributed elastic launch uses within a local worker group (processes running in a single pod) to determine which of the processes in the pod died first. In this sense, we are just extending the mechanism to multiple nodes.
I would argue that we treat error propagation in a best-effort manner here which will improve the situation drastically given that currently there is a race condition which favors the non-root cause errors. If sometimes we still don't identify the right root cause errors, users can still combine logs of multiple pods and search all of them.

  • Should we put the new files in an errors/ folder as a sibling to the error.pb today?

I'm open to both, if you have a preference, I'll put that into the rfc.

So broadly speaking there's a map/reduce problem with error docs. With flytekit changes, we can take care of the map part pretty easily. On the reduce side...

  • Are there cases where there can be one special Python worker that can iterate through all the child errors and merge them into one error document? (I'm assuming the answer is even if there are, this is rare and more likely it'll need to be handled by the backend plugin code)

I can't come up with a case where we would want to handle this in python/tasks. For torch distributed, communication between workers to aggregate errors would require them to rendezvous but this is very likely not possible anymore after the worker group crashed.

  • On the back-end side, there's two potential kinds of plugins - K8s plugins and non-K8s plugins.

    • For non-K8s plugins, we'd probably just ask plugin writers to do this as a special case of Handle() right? Scan output path for multiple error pb files, and somehow collate or merge them into one.

For this RFC, I'm only thinking about K8s plugins to be honest. I would expect authors of non-k8s plugins to do this in Handle(), yes.

  • For K8s plugins, we're saying that it feels too weird to do it as part of GetTaskPhase() right? I kinda agree.

Yes, definitely, plugins don't read error files in GetTaskPhase() currently afaik and I'd like to not blur that boundary.

Fabio Grätz added 2 commits September 9, 2024 23:34
@fg91 fg91 marked this pull request as ready for review September 9, 2024 21:37
@fg91 fg91 requested a review from bgedik September 9, 2024 21:37

Currently, [here](https://github.com/flyteorg/flyte/blob/4514860cf56ba62717f6c207f269410a8c1a5461/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go#L290) in the plugin manager, upon completion of a node execution, a new [`RemoteFileOutputReader`](https://github.com/flyteorg/flyte/blob/d6da838627d57cd27d60beea004e974ce1fb3ca5/flyteplugins/go/tasks/pluginmachinery/ioutils/remote_file_output_reader.go#L14) is constructed which is responsible for reading the error file uploaded to blob storage. This `RemoteFileOutputReader` implements the [`OutputReader` interface](https://github.com/flyteorg/flyte/blob/1e54d21c4d4ee74245f799a57b4bb8a5534e8368/flyteplugins/go/tasks/pluginmachinery/io/iface.go#L32).

We propose to implement a new `MultiErrorFileRemoteFileOutputReader` which (for future flexibility) can be configured with the different strategies we define. Initially, the only available strategy will be `"earliest"` which the RFC authors aim to use for the kubeflow pytorch plugin. This output reader will search for all error files in the `/errors` folder under the raw output prefix and aggregate the error as specified by the strategy.
Copy link
Contributor

@bgedik bgedik Sep 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of a MultiErrorFileRemoteFileOutputReader class, why not have RemoteFileOutputReader take an ErrorAggregationStrategy? Default is the current behavior.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The RemoteFileOutputReader's main goal is to read outputs. We can have multiple classes to extract the errors, which would be an internal detail to RemoteFileOutputReader.


We propose to implement a new `MultiErrorFileRemoteFileOutputReader` which (for future flexibility) can be configured with the different strategies we define. Initially, the only available strategy will be `"earliest"` which the RFC authors aim to use for the kubeflow pytorch plugin. This output reader will search for all error files in the `/errors` folder under the raw output prefix and aggregate the error as specified by the strategy.

If in [the plugin manager](https://github.com/flyteorg/flyte/blob/4514860cf56ba62717f6c207f269410a8c1a5461/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go#L290) the respective plugin is found to configure an error aggregation strategy other than `Default`, we instantiate such a `MultiErrorFileRemoteFileOutputReader` reader (instead of the existing `RemoteFileOutputReader`) and configure it with the respective strategy.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does the plugin manager get access to the plugin properties?

Copy link
Member Author

@fg91 fg91 Sep 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, we can call e.plugin.GetProperties()

@davidmirror-ops
Copy link
Contributor

davidmirror-ops commented Sep 26, 2024

09/26/2024 Contributors sync notes: Fabio provided a summary of the proposal, @pvditt to please review.

@fg91
Copy link
Member Author

fg91 commented Oct 23, 2024

@wild-endeavor and @EngHabu had these questions (I moved these questions here into the RFC to not disperse the discussion over multiple PRs, hope that's ok for you):

had a quick chat with @EngHabu about this... few questions

  • Is there nothing in the pytorch operator that will help with the aggregation across workers?
  • Is there a possible race condition in the logs? Like worker 1 experiences an error, but then gets delayed writing to s3, so worker 2 writes its error log first, and then worker 1's log appears later in the series.
  • is it possible to just head for the error.pb first, and only do a list if that is not there? then it won't waste the list.
    anything else @EngHabu?

@bgedik responded:

1/ I’ll research this and get back.
2/ The timestamp is not the timestamp of the write time but the timestamp of the exception happening. However, I don’t know when Flyte retrieves the error. So it is possible that we try to aggregate before everything is ready. I can do some code reading to understand this, but if anyone already has an answer, let me know.
3/ If we are going with the no strategy design and unify, we can certainly do that.


My replies:

  1. Torch distributed elastic has functionality to identify the first error that occurred in a local process group within a worker pod, see here. But it does not help with aggregating errors across pods. This is the information that is available:

    Events:
      Type     Reason                          Age                From                   Message
      ----     ------                          ----               ----                   -------
      Normal   SuccessfulCreatePod             27s                pytorchjob-controller  Created pod: ascvv8w42z8vvqh2drv9-n0-0-worker-0
      Warning  SettedPodTemplateRestartPolicy  27s (x2 over 27s)  pytorchjob-controller  Restart policy in pod template will be overwritten by restart policy in replica spec
      Normal   SuccessfulCreatePod             27s                pytorchjob-controller  Created pod: ascvv8w42z8vvqh2drv9-n0-0-worker-1
      Normal   SuccessfulCreateService         27s                pytorchjob-controller  Created service: ascvv8w42z8vvqh2drv9-n0-0-worker-0
      Normal   SuccessfulCreateService         27s                pytorchjob-controller  Created service: ascvv8w42z8vvqh2drv9-n0-0-worker-1
      Warning  Error                           23s (x2 over 23s)  pytorchjob-controller  Error pod ascvv8w42z8vvqh2drv9-n0-0-worker-1 container pytorch exitCode: 1 terminated message:
      Warning  Error                           21s (x3 over 23s)  pytorchjob-controller  Error pod ascvv8w42z8vvqh2drv9-n0-0-worker-0 container pytorch exitCode: 1 terminated message:
      Normal   ExitedWithCode                  21s (x3 over 23s)  pytorchjob-controller  Pod: flytesnacks-development.ascvv8w42z8vvqh2drv9-n0-0-worker-0 exited with code 1
      Normal   ExitedWithCode                  21s (x3 over 23s)  pytorchjob-controller  Pod: flytesnacks-development.ascvv8w42z8vvqh2drv9-n0-0-worker-1 exited with code 1
      Normal   PyTorchJobFailed                21s                pytorchjob-controller  PyTorchJob ascvv8w42z8vvqh2drv9-n0-0 is failed because 1 Worker replica(s) failed.
    status:
      completionTime: "2024-10-23T03:37:07Z"
      conditions:
      - lastTransitionTime: "2024-10-23T03:37:01Z"
        lastUpdateTime: "2024-10-23T03:37:01Z"
        message: PyTorchJob ascvv8w42z8vvqh2drv9-n0-0 is created.
        reason: PyTorchJobCreated
        status: "True"
        type: Created
      - lastTransitionTime: "2024-10-23T03:37:01Z"
        lastUpdateTime: "2024-10-23T03:37:01Z"
        message: PyTorchJob flytesnacks-development/ascvv8w42z8vvqh2drv9-n0-0 is running.
        reason: PyTorchJobRunning
        status: "False"
        type: Running
      - lastTransitionTime: "2024-10-23T03:37:07Z"
        lastUpdateTime: "2024-10-23T03:37:07Z"
        message: PyTorchJob ascvv8w42z8vvqh2drv9-n0-0 is failed because 1 Worker replica(s)
          failed.
        reason: PyTorchJobFailed
        status: "True"
        type: Failed
      replicaStatuses:
        Worker:
          active: 1
          failed: 1
          selector: training.kubeflow.org/job-name=ascvv8w42z8vvqh2drv9-n0-0,training.kubeflow.org/operator-name=pytorchjob-controller,training.kubeflow.org/replica-type=worker
      startTime: "2024-10-23T03:37:01Z"

    I don't see which information the training operator could provide to help with aggregating errors though. Propeller needs the ContainerError of the pod that failed with the first exception (not necessarily the one of the pod that died first but the one of the pod with the worker process of torch elastic launch that died first) to inform the user of what killed the job or to do determine whether the failure is recoverable, ...
    (The RFC mentions that "first error" is still only a best-effort way to try to identify the root cause error. The proposal basically is to extend the "earliest exception" mechanism that torch distributed already uses within a local process group to multiple worker pods.)

  2. My understanding is the following: Propeller will look for the error file once the PytorchJob CR goes into a terminal state, either Failed or more often Succeeded (because flytekit catches the error). This happens once any of the worker pod terminates. It is true that propeller could list the error files "too early" while not all of the pods have written theirs. With the current behavior we have a strong preference for considering the last error that occurred as the one that killed the job (because error files of different pods overwrite each other). I think with the RFC proposal, at least we would have a preference for the earliest error (which we look for) because it's the first pod that "should fail" and upload its error file.

    That being said, this is a good point you raised and I'd like to explore how we can improve this. Propeller in theory has the knowledge how many error files we would expect (the number of workers). In theory it could do an ls on the <raw output prefix>/errors bucket: if the number of error files is already the number of workers, we proceed, otherwise we'd wait for a certain amount of time and ls again. After a certain amount of time, assume no more errors will come.
    This would make things better while it's still not a guarantee. But I also don't see what else we can do without making flytepropeller consider the pods underlying a CRD (to see if any are still running which could still upload an error file) and not only the CRD itself - which would be a waay bigger change.

  3. This means that for distributed jobs we do an "unnecessary" check if error.pb exists (unnecessary because we could know it won't) but this is far better in my eyes than doing an "unnecessary" ls for all non distributed tasks. So I think we could do this if you prefer this behavior over the currently proposed different remote file output readers/strategies in propeller.

@bgedik
Copy link
Contributor

bgedik commented Oct 23, 2024

This means that for distributed jobs we do an "unnecessary" check if error.pb exists (unnecessary because we could know it won't) but this is far better in my eyes than doing an "unnecessary" ls for all non distributed tasks. So I think we could do this if you prefer this behavior over the currently proposed different remote file output readers/strategies in propeller.

@fg91 I thought about point 3 a bit more and I think this is a bit more than just having an extra HEAD call on error.pb.

Say we looked at error.pb and it is not there. Is it not there because we are using multiple error files or is it not there because there is no error? We would then need to do a listing to answer for sure. Then we are back to doing a listing. Specifically:

For non distributed cases: If error.pb file is missing, we need to have additional information on this being non distributed inside the remote output reader to avoid the listing, as lack of the file is not sufficient proof that we have no error.

For distributed cases: If error.pb file is missing, we need to have additional information on whether distributed error reporting is enabled or not inside the remote output reader to avoid an unnecessary listing.

Overall, having a strategy seems cleaner to me. May be we should discuss a bit what the downside is with the strategy approach.

@davidmirror-ops
Copy link
Contributor

10/24/2024 Contributor's sync notes: pending sync with Haytham to keep refining implementation approach.

@fg91 fg91 changed the title WIP: Deterministic error propagation for distributed (training) tasks Deterministic error propagation for distributed (training) tasks Nov 18, 2024
@fg91 fg91 merged commit 8e9616a into master Nov 18, 2024
48 of 50 checks passed
@fg91 fg91 deleted the fg91/rfc/distributed-jobs-error-handling branch November 18, 2024 17:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
rfc A label for RFC issues
Projects
Status: Accepted
Development

Successfully merging this pull request may close these issues.

6 participants